import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import scipy
from scipy.spatial.distance import cosine
import numpy as np
from torchvision.datasets import CocoCaptions
import open_clip
from scipy.io import loadmat


# COCO does not publish per-channel statistics, so the ImageNet mean and
# standard deviation are used for normalization instead.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalize each channel with the statistics above.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
def _flatten_indices(raw):
    """Return the contents of a loaded .mat array as a flat list of ints.

    ``scipy.io.loadmat`` hands back MATLAB matrices as (at least) 2-D NumPy
    arrays, so calling ``.tolist()`` directly produces a nested list such as
    ``[[3, 7, 9]]`` or ``[[3], [7], [9]]``.  ``Subset`` needs one plain integer
    per sample, hence the flattening and integer cast.
    """
    return np.asarray(raw).ravel().astype(int).tolist()


# Index lists (stored under key 'a') selecting the "error" and "canary"
# samples of the COCO training set.
# NOTE(review): indices are used as-is; if they were written from MATLAB with
# 1-based numbering they need an offset -- confirm against the producer.
dataerror = loadmat('./data/mat/clipmem/errorlist.mat')
errorlist = _flatten_indices(dataerror['a'])
datacanary = loadmat("./data/mat/clipmem/canarylist.mat")
canarylist = _flatten_indices(datacanary['a'])

coco_root = './data/datasets/coco/'
ann_file = coco_root + 'annotations/captions_train2014.json'
img_dir = coco_root + 'train2014/'
dataset = CocoCaptions(img_dir, ann_file, transform)
errorset = torch.utils.data.Subset(dataset, errorlist)
canaryset = torch.utils.data.Subset(dataset, canarylist)

# shuffle=False keeps the iteration order identical to the index lists, so the
# i-th score computed from canarydataloader belongs to canarylist[i].  With
# shuffling the saved scores could not be mapped back to their samples.
errordataloader = DataLoader(errorset, batch_size=1, shuffle=False, num_workers=4)
canarydataloader = DataLoader(canaryset, batch_size=1, shuffle=False, num_workers=4)
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# checkpoints that come from a trusted source.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
# Assumes each checkpoint stores a whole model object exposing
# encode_text / encode_image (that is how it is used below) -- TODO confirm.
modelf = torch.load('./data/model/clip/mineclip/100f.pt', map_location=device)
modelg = torch.load('./data/model/clip/mineclip/100g.pt', map_location=device)

# encode_text / encode_image are bound methods and have no .to(); the owning
# modules are moved to the device instead.  eval() switches the models to
# inference mode, since they are only used to compute embeddings here.
modelf = modelf.to(device)
modelg = modelg.to(device)
modelf.eval()
modelg.eval()

# Text and image encoders of model f and model g.
tencoderf = modelf.encode_text
iencoderf = modelf.encode_image
tencoderg = modelg.encode_text
iencoderg = modelg.encode_image

# Flattened image / text embeddings of every error-set sample, per model.
errorimgembf = []
errorimgembg = []
errortextembf = []
errortextembg = []
# One score per canary sample; filled by the canary loop.
clipmem = []

# Embeddings are only read, never back-propagated through, so gradient
# tracking is disabled for the whole pass.
with torch.no_grad():
    for images, captions in errordataloader:
        images = images.to(device)
        # The captions arrive as Python strings, which have no .to(); only the
        # token tensor produced from the first caption is moved to the device
        # the encoders live on.
        tokens = open_clip.tokenize(captions[0]).to(device)
        errortextembf.append(tencoderf(tokens).reshape(-1).detach().cpu().numpy())
        errortextembg.append(tencoderg(tokens).reshape(-1).detach().cpu().numpy())
        errorimgembf.append(iencoderf(images).reshape(-1).detach().cpu().numpy())
        errorimgembg.append(iencoderg(images).reshape(-1).detach().cpu().numpy())

def _alignment(image_emb, text_emb, error_text_embs, error_img_embs):
    """Return the alignment score of one image/caption pair.

    The score is the cosine similarity of the pair itself, minus the mean
    similarity of the image to every error-set text embedding, minus the mean
    similarity of the caption to every error-set image embedding.

    Args:
        image_emb: 1-D image embedding of the sample.
        text_emb: 1-D text embedding of the sample's caption.
        error_text_embs: iterable of 1-D error-set text embeddings.
        error_img_embs: iterable of 1-D error-set image embeddings.
    """
    pair_similarity = 1 - cosine(image_emb, text_emb)
    text_mean = np.mean([1 - cosine(image_emb, vec) for vec in error_text_embs])
    image_mean = np.mean([1 - cosine(vec, text_emb) for vec in error_img_embs])
    return pair_similarity - text_mean - image_mean


# Inference only: no gradients are needed for the embeddings.
with torch.no_grad():
    for images, captions in canarydataloader:
        images = images.to(device)
        # The captions arrive as Python strings, which have no .to(); only the
        # token tensor produced from the first caption is moved to the device.
        tokens = open_clip.tokenize(captions[0]).to(device)

        # for model f
        imageembf = iencoderf(images).reshape(-1).detach().cpu().numpy()
        textembf = tencoderf(tokens).reshape(-1).detach().cpu().numpy()
        # Each error-set embedding is compared individually; the earlier code
        # appended the whole list as a single element, so the distance
        # function received a list of vectors instead of one vector.
        alignf = _alignment(imageembf, textembf, errortextembf, errorimgembf)

        # for model g
        imageembg = iencoderg(images).reshape(-1).detach().cpu().numpy()
        textembg = tencoderg(tokens).reshape(-1).detach().cpu().numpy()
        aligng = _alignment(imageembg, textembg, errortextembg, errorimgembg)

        # clipmem: difference between the two models' alignment scores
        clipmem.append(alignf - aligng)

# Scale the scores by their spread (max - min) and save them.
clipmem = np.array(clipmem)
# Named score_range so the builtin range() is not shadowed.  An empty score
# list has no spread; treat it like a zero spread instead of raising.
score_range = clipmem.max() - clipmem.min() if clipmem.size else 0.0
# When every score is identical (or there are none) the spread is zero and
# dividing would only produce inf/nan, so the scores are left unscaled.
if score_range > 0:
    clipmem = clipmem / score_range
scipy.io.savemat('clipmem_np_e.mat', {'noncrop_error': clipmem})